"""
Reference: https://github.com/XiaoxinHe/TAPE/blob/main/core/data_utils/load_cora.py
"""

import os
import sys
sys.path.append('../')

import numpy as np
import torch
import random

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from utils import init_random_seed
from torch_geometric.utils import to_undirected, add_remaining_self_loops

# get_cora_casestudy (below) returns the Cora dataset as a PyTorch Geometric Data object with a random 60/20/20 train/val/test split, together with the array of Cora paper IDs.

# Human-readable class names keyed by integer label. The position of each
# name is its label id, and the order mirrors the class list that parse_cora
# uses to encode labels (same names, with spaces instead of underscores).
cora_mapping = dict(enumerate((
    "Case Based",
    "Genetic Algorithms",
    "Neural Networks",
    "Probabilistic Methods",
    "Reinforcement Learning",
    "Rule Learning",
    "Theory",
)))


def get_cora_casestudy(SEED=0, dataset_folder="/data/shared/graph_datasets_backup/"):
    """Load Cora from the raw files and attach a random 60/20/20 node split.

    The Planetoid copy of Cora is loaded only to obtain a ``Data`` container;
    its ``x``, ``edge_index``, ``y`` and ``num_nodes`` are then overwritten
    with the values parsed by ``parse_cora``.

    Args:
        SEED: Forwarded to ``init_random_seed`` before the node permutation is
            drawn (presumably seeds NumPy's global RNG, which drives
            ``np.random.shuffle`` -- confirm in ``utils``).
        dataset_folder: Root folder handed to both ``parse_cora`` and
            ``Planetoid``.

    Returns:
        A tuple ``(data, data_citeid)``. ``data`` carries ``train_id`` /
        ``val_id`` / ``test_id`` (sorted NumPy index arrays) and the matching
        boolean tensors ``train_mask`` / ``val_mask`` / ``test_mask``.
        ``data_citeid`` is the paper-id array returned by ``parse_cora``.
    """
    data_X, data_Y, data_citeid, data_edges = parse_cora(dataset_folder)

    init_random_seed(SEED)

    dataset = Planetoid(dataset_folder, 'cora',
                        transform=T.NormalizeFeatures())
    data = dataset[0]

    # Replace the Planetoid tensors with the ones parsed from the raw files.
    # Because data.x is overwritten here, the normalized Planetoid features
    # are not what ends up in the returned object.
    num_nodes = len(data_Y)
    data.x = torch.tensor(data_X).float()
    data.edge_index = torch.tensor(data_edges).long()
    data.y = torch.tensor(data_Y).long()
    data.num_nodes = num_nodes

    # Random permutation of node indices, cut at 60% and 80%.
    node_id = np.arange(num_nodes)
    np.random.shuffle(node_id)
    train_end = int(num_nodes * 0.6)
    val_end = int(num_nodes * 0.8)

    data.train_id = np.sort(node_id[:train_end])
    data.val_id = np.sort(node_id[train_end:val_end])
    data.test_id = np.sort(node_id[val_end:])

    def _ids_to_mask(ids):
        # Scatter True into a zeroed mask. This yields the same bool tensor as
        # testing every node index for membership, without the per-node
        # linear scan over the id array.
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[torch.as_tensor(ids, dtype=torch.long)] = True
        return mask

    data.train_mask = _ids_to_mask(data.train_id)
    data.val_mask = _ids_to_mask(data.val_id)
    data.test_mask = _ids_to_mask(data.test_id)

    return data, data_citeid

# credit: https://github.com/tkipf/pygcn/issues/27, xuhaiyun


def parse_cora(dataset_folder):
    """Parse the raw Cora ``cora.content`` and ``cora.cites`` files.

    Both files are read from ``<dataset_folder>/cora_orig/``. Each row of
    ``cora.content`` is ``<paper id> <feature columns...> <class name>``;
    each row of ``cora.cites`` is a pair of paper ids.

    Args:
        dataset_folder: Root folder containing ``cora_orig``. A trailing path
            separator is accepted but not required.

    Returns:
        A tuple ``(data_X, data_Y, data_citeid, data_edges)``:
        ``data_X`` is the float32 feature matrix, ``data_Y`` the integer class
        ids, ``data_citeid`` the paper ids (strings) in row order, and
        ``data_edges`` a ``2 x E`` integer array of unique row-index pairs
        containing every citation in both directions.

    Raises:
        KeyError: If a row carries a class name outside the seven known ones.
    """
    path = os.path.join(dataset_folder, 'cora_orig', 'cora')

    idx_features_labels = np.genfromtxt(
        "{}.content".format(path), dtype=np.dtype(str))
    data_X = idx_features_labels[:, 1:-1].astype(np.float32)
    labels = idx_features_labels[:, -1]

    class_names = ['Case_Based', 'Genetic_Algorithms', 'Neural_Networks',
                   'Probabilistic_Methods', 'Reinforcement_Learning',
                   'Rule_Learning', 'Theory']
    class_map = {name: i for i, name in enumerate(class_names)}
    data_Y = np.array([class_map[label] for label in labels])

    # Paper id -> row index in the content file.
    data_citeid = idx_features_labels[:, 0]
    idx_map = {paper_id: row for row, paper_id in enumerate(data_citeid)}

    # atleast_2d keeps a single-citation file (which genfromtxt returns as a
    # 1-D array) iterable as rows of (id, id).
    edges_unordered = np.atleast_2d(np.genfromtxt(
        "{}.cites".format(path), dtype=np.dtype(str)))

    # Drop citations that mention a paper absent from the content file.
    known_edges = [(idx_map[src], idx_map[dst])
                   for src, dst in edges_unordered
                   if src in idx_map and dst in idx_map]
    data_edges = np.array(known_edges, dtype='int').reshape(-1, 2)

    # Symmetrize by appending every edge reversed, then de-duplicate rows.
    data_edges = np.vstack((data_edges, np.fliplr(data_edges)))
    return data_X, data_Y, data_citeid, np.unique(data_edges, axis=0).transpose()


def get_raw_text(seed=0, dataset_folder="/data/shared/graph_datasets_backup/"):
    """Load Cora together with each paper's raw title and abstract lines.

    The graph comes from ``get_cora_casestudy``. Text is read from the
    McCallum extraction files under
    ``<dataset_folder>/cora_orig/mccallum/cora/``: the ``papers`` file maps a
    paper id (first tab-separated field) to an extraction file name (second
    field), and each extraction file is scanned for lines containing
    ``Title:`` and ``Abstract:``.

    Args:
        seed: Forwarded to ``get_cora_casestudy``.
        dataset_folder: Root folder of the dataset files.

    Returns:
        A tuple ``(data, text, cora_mapping)``. ``text`` is a dict with the
        parallel lists ``'title'``, ``'content'`` and ``'label'``, one entry
        per node in ``data`` order. ``data.edge_index`` has been made
        undirected and extended with self loops.

    Raises:
        KeyError: If a paper id has no entry in the ``papers`` file.
    """
    data, data_citeid = get_cora_casestudy(seed, dataset_folder)

    mccallum_dir = os.path.join(dataset_folder, 'cora_orig', 'mccallum', 'cora')

    pid_filename = {}
    with open(os.path.join(mccallum_dir, 'papers')) as f:
        for line in f:
            fields = line.split('\t')
            pid_filename[fields[0]] = fields[1]

    extraction_dir = os.path.join(mccallum_dir, 'extractions')

    text = {'title': [], 'content': [], 'label': []}

    for pid in data_citeid:
        with open(os.path.join(extraction_dir, pid_filename[pid])) as f:
            lines = f.read().splitlines()

        # Reset per paper so that a file lacking a Title:/Abstract: line
        # yields an empty string rather than the previous paper's text.
        ti, ab = '', ''
        for line in lines:
            # The last matching line wins when a marker occurs more than once.
            if 'Title:' in line:
                ti = line
            if 'Abstract:' in line:
                ab = line
        text['title'].append(ti)
        text['content'].append(ab)

    text['label'].extend(cora_mapping[label] for label in data.y.tolist())

    # num_nodes is passed by keyword: the second positional parameter of
    # to_undirected is edge_attr in current PyTorch Geometric releases.
    data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
    data.edge_index, _ = add_remaining_self_loops(data.edge_index, num_nodes=data.num_nodes)

    return data, text, cora_mapping